# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from collections import defaultdict
from typing import Any, Mapping, Optional, Sequence, Union

from datasets.builder import DatasetBuilder

from modelscope.hub.api import HubApi
from modelscope.utils.constant import DEFAULT_DATASET_REVISION
from modelscope.utils.logger import get_logger
from .dataset_builder import MsCsvDatasetBuilder, TaskSpecificDatasetBuilder

logger = get_logger()


def format_dataset_structure(dataset_structure):
    """
    Keep only the entries of a dataset structure that reference a meta
    file or a data file; drop empty placeholders.
    """
    return {
        k: v
        for k, v in dataset_structure.items()
        if (v.get('meta') or v.get('file'))
    }


def get_target_dataset_structure(dataset_structure: dict,
                                 subset_name: Optional[str] = None,
                                 split: Optional[str] = None):
    """
    Args:
        dataset_structure (dict): Dataset Structure, like
         {
            "default":{
                "train":{
                    "meta":"my_train.csv",
                    "file":"pictures.zip"
                }
            },
            "subsetA":{
                "test":{
                    "meta":"mytest.csv",
                    "file":"pictures.zip"
                }
            }
        }
        subset_name (str, optional): Name of the dataset subset to load.
        split (str, optional): Which split of the data to load.
    Returns:
           target_subset_name (str): Name of the chosen subset.
           target_dataset_structure (dict): Structure of the chosen split(s), like
           {
               "test":{
                        "meta":"mytest.csv",
                        "file":"pictures.zip"
                    }
            }
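    Examples:
        A minimal illustrative call (values are hypothetical):

        >>> structure = {
        ...     'default': {'train': {'meta': 'my_train.csv', 'file': 'pictures.zip'}}
        ... }
        >>> get_target_dataset_structure(structure, split='train')
        ('default', {'train': {'meta': 'my_train.csv', 'file': 'pictures.zip'}})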
    """
    # verify dataset subset
    if (subset_name and subset_name not in dataset_structure) or (
            not subset_name and len(dataset_structure.keys()) > 1):
        raise ValueError(
            f'subset_name {subset_name} not found or not specified. '
            f'Available: {list(dataset_structure.keys())}'
        )
    target_subset_name = subset_name
    if not subset_name:
        target_subset_name = next(iter(dataset_structure.keys()))
        logger.info(
            f'No subset_name specified, defaulting to subset: {target_subset_name}'
        )
    # verify dataset split
    target_dataset_structure = format_dataset_structure(
        dataset_structure[target_subset_name])
    if split and split not in target_dataset_structure:
        raise ValueError(
            f'split {split} not found. Available: {list(target_dataset_structure.keys())}'
        )
    if split:
        target_dataset_structure = {split: target_dataset_structure[split]}
    return target_subset_name, target_dataset_structure


def list_dataset_objects(hub_api: HubApi, max_limit: int, is_recursive: bool,
                         dataset_name: str, namespace: str,
                         version: str) -> list:
    """
    List all objects of a specific dataset.

    Args:
        hub_api (class HubApi): HubApi instance.
        max_limit (int): Max number of objects.
        is_recursive (bool): Whether to list objects recursively.
        dataset_name (str): Dataset name.
        namespace (str): Namespace.
        version (str): Dataset version.
    Returns:
        res (list): List of objects, i.e., ['train/images/001.png', 'train/images/002.png', 'val/images/001.png', ...]
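    Examples:
        Illustrative only; listing requires network access to the ModelScope hub,
        and the dataset name and namespace below are hypothetical:

        >>> api = HubApi()
        >>> objects = list_dataset_objects(api, max_limit=-1, is_recursive=True,
        ...                                dataset_name='my_dataset',
        ...                                namespace='my_namespace', version='master')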
    """
    res = []
    objects = hub_api.list_oss_dataset_objects(
        dataset_name=dataset_name,
        namespace=namespace,
        max_limit=max_limit,
        is_recursive=is_recursive,
        is_filter_dir=True,
        revision=version)

    for item in objects:
        object_key = item.get('Key')
        # Skip entries without an object key.
        if object_key:
            res.append(object_key)

    return res


def contains_dir(file_map) -> bool:
    """
    Check whether the input contains at least one directory.

    Any value that does not end with '.zip' is treated as a directory path.

    Args:
        file_map (dict): Structure of data files. e.g., {'train': 'train.zip', 'validation': 'val.zip'}
    Returns:
        True if the input contains at least one directory, False otherwise.
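    Examples:
        >>> contains_dir({'train': 'train.zip', 'validation': 'val.zip'})
        False
        >>> contains_dir({'train': 'train', 'validation': 'val'})
        True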
    """
    # A non-.zip value is treated as a directory prefix.
    return any(
        isinstance(v, str) and not v.endswith('.zip')
        for v in file_map.values())


def get_split_objects_map(file_map, objects):
    """
    Get the map between dataset split and oss objects.

    Args:
        file_map (dict): Structure of data files. e.g., {'train': 'train', 'validation': 'val'}, where both
            'train' and 'val' are directories.
        objects (list): List of oss objects. e.g., ['train/001/1_123.png', 'train/001/1_124.png', 'val/003/3_38.png']
    Returns:
        A map of split-objects. e.g., {'train': ['train/001/1_123.png', 'train/001/1_124.png'],
            'validation':['val/003/3_38.png']}
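    Examples:
        >>> file_map = {'train': 'train', 'validation': 'val'}
        >>> objects = ['train/001/1_123.png', 'train/001/1_124.png', 'val/003/3_38.png']
        >>> get_split_objects_map(file_map, objects)
        {'train': ['train/001/1_123.png', 'train/001/1_124.png'], 'validation': ['val/003/3_38.png']}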
    """
    res = {k: [] for k in file_map}

    for obj_key in objects:
        for k, v in file_map.items():
            if obj_key.startswith(v):
                res[k].append(obj_key)

    return res


def get_dataset_files(subset_split_into: dict,
                      dataset_name: str,
                      namespace: str,
                      revision: Optional[str] = DEFAULT_DATASET_REVISION):
    """
    Resolve the remote meta files, data files and extra arguments for each split.

    Args:
        subset_split_into (dict): Structure of the chosen subset, mapping each split
            name to its meta/file/args entry.
        dataset_name (str): Dataset name.
        namespace (str): Namespace.
        revision (str, optional): Dataset version.
    Returns:
        meta_map: Structure of meta files (.csv); the meta file name is replaced by its url, like
        {
           "test": "https://xxx/mytest.csv"
        }
        file_map: Structure of data files (.zip), like
        {
            "test": "pictures.zip"
        }
        args_map: Extra arguments per split, taken from the 'args' field of each split entry.
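    Examples:
        Illustrative only; resolving urls requires network access, and the
        names below are hypothetical:

        >>> meta_map, file_map, args_map = get_dataset_files(
        ...     {'train': {'meta': 'my_train.csv', 'file': 'pictures.zip'}},
        ...     dataset_name='my_dataset', namespace='my_namespace')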
    """
    meta_map = defaultdict(dict)
    file_map = defaultdict(dict)
    args_map = defaultdict(dict)
    modelscope_api = HubApi()
    objects = list_dataset_objects(
        hub_api=modelscope_api,
        max_limit=-1,
        is_recursive=True,
        dataset_name=dataset_name,
        namespace=namespace,
        version=revision)

    for split, info in subset_split_into.items():
        meta_map[split] = modelscope_api.get_dataset_file_url(
            info.get('meta', ''), dataset_name, namespace, revision)
        if info.get('file'):
            file_map[split] = info['file']
        args_map[split] = info.get('args')

    if contains_dir(file_map):
        file_map = get_split_objects_map(file_map, objects)
    return meta_map, file_map, args_map


def load_dataset_builder(dataset_name: str, subset_name: str, namespace: str,
                         meta_data_files: Mapping[str, Union[str,
                                                             Sequence[str]]],
                         zip_data_files: Mapping[str, Union[str,
                                                            Sequence[str]]],
                         args_map: Mapping[str, Any], cache_dir: str,
                         version: Optional[str], split: Sequence[str],
                         **config_kwargs) -> DatasetBuilder:
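    """
    Create a DatasetBuilder for the given meta/zip data files.

    Falls back to TaskSpecificDatasetBuilder when no meta file is available;
    uses MsCsvDatasetBuilder when the meta file is a .csv. Other meta file
    formats are not supported yet.
    """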
    sub_dir = os.path.join(version, '_'.join(split))
    meta_data_file = next(iter(meta_data_files.values()))
    if not meta_data_file:
        # No meta file available: build from the per-split args plus any
        # extra config kwargs.
        builder_args = next(iter(args_map.values()))
        if builder_args is None:
            builder_args = {}
        builder_args.update(config_kwargs)
        builder_instance = TaskSpecificDatasetBuilder(
            dataset_name=dataset_name,
            namespace=namespace,
            cache_dir=cache_dir,
            subset_name=subset_name,
            meta_data_files=meta_data_files,
            zip_data_files=zip_data_files,
            hash=sub_dir,
            **builder_args)
    elif meta_data_file.endswith('.csv'):
        builder_instance = MsCsvDatasetBuilder(
            dataset_name=dataset_name,
            namespace=namespace,
            cache_dir=cache_dir,
            subset_name=subset_name,
            meta_data_files=meta_data_files,
            zip_data_files=zip_data_files,
            hash=sub_dir)
    else:
        raise NotImplementedError(
            f'Dataset meta file extension "{os.path.splitext(meta_data_file)[-1]}" is not implemented yet'
        )

    return builder_instance
